# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mindspore quantum simulator evolution operator."""

import numpy as np
from mindquantum.ops import QubitOperator
from mindspore import Tensor
from mindspore.ops.primitive import PrimitiveWithInfer
from mindspore.ops.primitive import prim_attr_register
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
from mindquantum.gate import Hamiltonian
from ._check_qnn_input import _check_circuit
from ._check_qnn_input import _check_non_parameterized_circuit
from ._check_qnn_input import _check_type_or_iterable_type
from ._check_qnn_input import _check_parameters_of_circuit


class Evolution(PrimitiveWithInfer):
    r"""
    Quantum circuit evolution operator for the MindQuantum CPU simulator.

    The constructor arguments of this operator are generated by the
    MindQuantum framework (for example by ``generate_evolution_operator``)
    and are normally not written by hand.

    Args:
        n_qubits (int): The qubit number of quantum simulator.
        param_names (list[str]): The parameters names.
        gate_names (list[str]): The name of each gate.
        gate_matrix (list[list[list[list[float]]]]): Real part and imaginary part of the matrix of quantum gate.
        gate_obj_qubits (list[list[int]]): Object qubits of each gate.
        gate_ctrl_qubits (list[list[int]]): Control qubits of each gate.
        gate_params_names (list[list[str]]): Parameter names of each gate.
        gate_coeff (list[list[float]]): Coefficient of each parameter of each gate.
        gate_requires_grad (list[list[bool]]): Whether to calculate gradient of parameters of gates.
        hams_pauli_coeff (list[list[float]]): Coefficient of pauli words.
        hams_pauli_word (list[list[list[str]]]): Pauli words.
        hams_pauli_qubit (list[list[list[int]]]): The qubit that pauli matrix act on.

    Inputs:
        - **param_data** (Tensor) - One-dimensional float tensor with the
          parameter values (presumably ordered like ``param_names``). May be
          omitted when the operator has no parameters.

    Outputs:
        numpy.ndarray, the complex quantum state after evolution, of length
        ``2**n_qubits``.

    Supported Platforms:
        ``CPU``
    """
    @prim_attr_register
    def __init__(self, n_qubits, param_names, gate_names, gate_matrix,
                 gate_obj_qubits, gate_ctrl_qubits, gate_params_names,
                 gate_coeff, gate_requires_grad, hams_pauli_coeff,
                 hams_pauli_word, hams_pauli_qubit):
        """Initialize Evolution."""
        # NOTE(review): presumably ``prim_attr_register`` stores every
        # constructor argument as an attribute (``self.param_names`` is read
        # in ``__call__`` but never assigned here) — confirm.
        self.init_prim_io_names(inputs=['param_data'], outputs=['state'])
        self.n_qubits = n_qubits

    def check_shape_size(self, param_data):
        """Raise ValueError unless the shape ``param_data`` has exactly one dimension."""
        if len(param_data) != 1:
            raise ValueError("PQC input param_data should have dimension size "
                             "equal to 1, but got {}.".format(len(param_data)))

    def infer_shape(self, param_data):
        """Infer the output shape: one (real, imaginary) pair per amplitude."""
        self.check_shape_size(param_data)
        return [1 << self.n_qubits, 2]

    def infer_dtype(self, param_data):
        """Check ``param_data`` is a float tensor type; the output uses the same dtype."""
        args = {'param_data': param_data}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type,
                                                      self.name)
        return param_data

    def __call__(self, tmp=None):
        """
        Run the simulation and return the final state as a complex array.

        Args:
            tmp (Tensor): Parameter values. May be None only when the operator
                has no parameters. Default: None.

        Returns:
            numpy.ndarray, the complex quantum state.

        Raises:
            ValueError: If ``tmp`` is None but ``param_names`` is not empty.
        """
        if tmp is None:
            if self.param_names:
                raise ValueError(
                    "Parameterized circuit should have parameter input.")
            # The primitive still requires a 1-D float tensor input; for a
            # circuit without parameters its value is presumably unused.
            tmp = Tensor(np.array([0]).astype(np.float32))
        state = super().__call__(tmp).asnumpy()
        # Column 0 holds the real part, column 1 the imaginary part.
        return state[:, 0] + state[:, 1] * 1j


def generate_evolution_operator(circuit, param_names=None, hams=None):
    """
    A method to generate a parameterized quantum circuit simulation operator.

    Args:
        circuit (Circuit): The whole circuit combined with
            encoder circuit and ansatz circuit, can be a parameterized circuit
            or a non parameterized circuit.
        param_names (list[str]): The list of parameter names. If None, the
            parameter names are taken from ``circuit.para_name``. If the
            resulting list is empty, the circuit must be a non parameterized
            circuit. Default: None.
        hams (Union[Hamiltonian, list[Hamiltonian]]): The measurement
            hamiltonian. If None (or an empty list), then no hamiltonians will
            be applied on the final quantum state. Default: None.

    Returns:
        Evolution, A parameterized quantum circuit simulator operator supported by mindspore framework.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> import mindquantum.gate as G
        >>> from mindquantum import Circuit
        >>> from mindquantum.nn import generate_evolution_operator
        >>> circ = Circuit(G.RX('a').on(0))
        >>> evol = generate_evolution_operator(circ, ['a'])
        >>> evol(Tensor(np.array([0.5]).astype(np.float32)))
        array([0.9689124+0.j        , 0.       -0.24740396j], dtype=complex64)
        >>> G.RX(0.5).matrix()[:, 0]
        array([0.96891242+0.j        , 0.        -0.24740396j])
    """
    if param_names is None:
        param_names = circuit.para_name
    if not param_names:
        _check_non_parameterized_circuit(circuit)
    else:
        _check_circuit(circuit, 'circuit')
        _check_parameters_of_circuit([], param_names, circuit)
    if hams is not None:
        _check_type_or_iterable_type(hams, Hamiltonian, 'Hamiltonian')
    if isinstance(hams, Hamiltonian):
        hams = [hams]
    elif hams is not None:
        hams = list(hams)
    if not hams:
        # No measurement Hamiltonian (None or an empty collection): use the
        # data of an empty QubitOperator. Previously an empty list produced
        # an empty dict and Evolution failed with missing-argument errors.
        ham_ms_data = Hamiltonian(QubitOperator()).mindspore_data()
    else:
        # Gather each data field into one list with one entry per Hamiltonian.
        ham_ms_data = {}
        for ham in hams:
            for key, value in ham.mindspore_data().items():
                ham_ms_data.setdefault(key, []).append(value)
    evol = Evolution(circuit.n_qubits,
                     param_names=param_names,
                     **circuit.mindspore_data(),
                     **ham_ms_data)
    return evol
